Skip to content

在scaled_dot_product_attention函数中加入bool mask #72927

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 6, 2025

Conversation

Qin-sx
Copy link
Contributor

@Qin-sx Qin-sx commented May 25, 2025

PR Category

User Experience

PR Types

Improvements

Description

在scaled_dot_product_attention函数中加入bool mask

	modified:   python/paddle/nn/functional/flash_attention.py
	modified:   test/legacy_test/test_flash_attention.py
Copy link

paddle-bot bot commented May 25, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 25, 2025
Qin-sx added 6 commits May 26, 2025 08:57
	modified:   test/legacy_test/test_flash_attention.py
	modified:   test/legacy_test/test_flash_attention.py
	new file:   test/legacy_test/test_scaled_dot_product_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
	modified:   python/paddle/nn/functional/flash_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
@@ -1272,6 +1284,7 @@ def scaled_dot_product_attention(
sdp_func_name = _select_sdp_for_sdpa(
query, key, attn_mask, dropout_p, is_causal
)
attn_mask = _convert_bool_mask_to_float(attn_mask, query.dtype)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个函数的逻辑比较简单,是不是可以直接写到这里来

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,收到,已修改

out_ = attention_naive_with_mask(q_, k_, v_, m)
out.backward()
out_.backward()
np.testing.assert_allclose(out.numpy(), out_, rtol=5e-03, atol=1e-03)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你本地在PaConvert里的sdpa的单测里加一下attn_mask为bool的测试例子,测试一下计算结果是否和pytorch 一致。附一下paconvert测试结果。

然后映射文档也记得修改下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯,已修改。之前用的是python3.8虚拟环境,升级python3.9重装环境花费了一些时间。
PaddlePaddle/PaConvert#586

Qin-sx added 6 commits May 29, 2025 00:01
	modified:   python/paddle/nn/functional/flash_attention.py
	modified:   python/paddle/nn/functional/flash_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
	modified:   test/legacy_test/test_scaled_dot_product_attention.py
@Qin-sx
Copy link
Contributor Author

Qin-sx commented Jun 1, 2025

在docker paddlepaddle/paddle:latest-dev-cuda11.8-cudnn8.6-trt8.5-gcc82中用一下命令编译

cmake .. -DPY_VERSION=3.9 -DWITH_GPU=ON -DWITH_TENSORRT=ON -DWITH_TESTING=ON

测试test_quant_linear_fuse_pass可以通过

python ../test/ir/inference/test_quant_linear_fuse_pass.py

grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
RuntimeError: module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000
/paddle/test/ir/inference/auto_scan_test.py:61: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:70: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:461: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
Sun Jun 01 09:14:22-INFO: Start to running test of <class '__main__.TestQuantLinearFusePass'>
I0601 09:14:21.751210 29451 program_interpreter.cc:257] New Executor is Running.
Sun Jun 01 09:14:23-INFO: Number of Invalid Programs: 0
Sun Jun 01 09:14:23-INFO: Number of Ran Programs: 30
Sun Jun 01 09:14:23-INFO: Number of Ignore Tests: 0
.
----------------------------------------------------------------------
Ran 1 test in 1.812s

OK

@zhwesky2010
Copy link
Contributor

看下CI没过

@Qin-sx
Copy link
Contributor Author

Qin-sx commented Jun 4, 2025

看下CI没过

嗯,可否帮忙看一下test_quant_linear_fuse_pass这个测试,修改的函数应该对这个测试没有影响。而且之前应该还有一个CI跑了几次没跑过,今天重新跑了之后跑过了。
Coverage / Coverage test (pull_request)这个测试不太清楚如何重新启动。

在docker paddlepaddle/paddle:latest-dev-cuda11.8-cudnn8.6-trt8.5-gcc82中用一下命令编译

cmake .. -DPY_VERSION=3.9 -DWITH_GPU=ON -DWITH_TENSORRT=ON -DWITH_TESTING=ON

测试test_quant_linear_fuse_pass可以通过

python ../test/ir/inference/test_quant_linear_fuse_pass.py

grep: warning: GREP_OPTIONS is deprecated; please use an alias or script
RuntimeError: module compiled against ABI version 0x1000009 but this version of numpy is 0x2000000
/paddle/test/ir/inference/auto_scan_test.py:61: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:70: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
/paddle/test/ir/inference/auto_scan_test.py:461: HypothesisDeprecationWarning: `Healthcheck.all()` is deprecated; use `list(HealthCheck)` instead.
    The `hypothesis codemod` command-line tool can automatically refactor your code to fix this warning.
  suppress_health_check=hypothesis.HealthCheck.all(),
Sun Jun 01 09:14:22-INFO: Start to running test of <class '__main__.TestQuantLinearFusePass'>
I0601 09:14:21.751210 29451 program_interpreter.cc:257] New Executor is Running.
Sun Jun 01 09:14:23-INFO: Number of Invalid Programs: 0
Sun Jun 01 09:14:23-INFO: Number of Ran Programs: 30
Sun Jun 01 09:14:23-INFO: Number of Ignore Tests: 0
.
----------------------------------------------------------------------
Ran 1 test in 1.812s

OK

Copy link
Contributor

@zhwesky2010 zhwesky2010 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zhwesky2010 zhwesky2010 merged commit caa5621 into PaddlePaddle:develop Jun 6, 2025
48 of 50 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants